import logging
from pprint import pformat
from typing import Any, Dict, List, Sequence, Set

# common imports
from common.event_queue import IAgentEventPublisher
from common.types import AgentID, Event, NetworkPort
from common.utils.code_utils import del_key

# dependencies to get rid of or internalize
from infection_monkey.exploit import IAgentOTPProvider, IHTTPAgentBinaryServerRegistrar
from infection_monkey.exploit.tools import filter_out_closed_ports, get_open_http_ports
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.network import TCPPortSelector

from .log4shell_exploit_client import Log4ShellExploitClient
from .log4shell_exploiter import Log4ShellExploiter
from .log4shell_options import Log4ShellOptions

logger = logging.getLogger(__name__)


def get_ports_to_try(
    host: TargetHost, target_ports: List[NetworkPort], try_all_discovered_http_ports: bool
) -> Set[NetworkPort]:
    """
    Select the ports on which an exploitation attempt should be made.

    :param host: Target host, passed to the port-filtering helpers
    :param target_ports: Ports configured as exploitation targets
    :param try_all_discovered_http_ports: If True, the ports returned by
        ``get_open_http_ports(host)`` are added to the candidates
    :return: The result of ``filter_out_closed_ports(host, target_ports)``, extended in place
        with the host's open HTTP ports when requested
    """
    candidate_ports = filter_out_closed_ports(host, target_ports)

    if not try_all_discovered_http_ports:
        return candidate_ports

    candidate_ports.update(get_open_http_ports(host))
    return candidate_ports


def should_attempt_exploit(ports_to_try: Set[NetworkPort]) -> bool:
    """Return True when there is at least one candidate port, otherwise False."""
    if ports_to_try:
        return True
    return False


class Plugin:
    """
    Log4Shell exploiter plugin entry point.

    On construction, builds a ``Log4ShellExploitClient`` and wraps it in a
    ``Log4ShellExploiter``. Each ``run()`` call parses the plugin options, selects the ports
    to try and delegates the exploitation attempt to that exploiter.
    """

    def __init__(
        self,
        *,
        plugin_name: str,
        agent_id: AgentID,
        agent_event_publisher: IAgentEventPublisher,
        http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
        tcp_port_selector: TCPPortSelector,
        otp_provider: IAgentOTPProvider,
        **kwargs,
    ):
        """
        :param plugin_name: Name of this plugin, passed to ``Log4ShellExploitClient``
        :param agent_id: ID of the running agent, passed to the client and the exploiter
        :param agent_event_publisher: Event publisher passed to ``Log4ShellExploitClient``
        :param http_agent_binary_server_registrar: Passed through to ``Log4ShellExploiter``
        :param tcp_port_selector: Passed through to ``Log4ShellExploiter``
        :param otp_provider: Passed through to ``Log4ShellExploiter``
        :param kwargs: Any additional keyword arguments are accepted and ignored
        """
        log4shell_exploit_client = Log4ShellExploitClient(
            plugin_name, agent_id, agent_event_publisher
        )

        self._log4shell_exploiter = Log4ShellExploiter(
            agent_id=agent_id,
            log4shell_exploit_client=log4shell_exploit_client,
            tcp_port_selector=tcp_port_selector,
            http_agent_binary_server_registrar=http_agent_binary_server_registrar,
            otp_provider=otp_provider,
        )

    def run(
        self,
        *,
        host: TargetHost,
        servers: Sequence[str],
        current_depth: int,
        options: Dict[str, Any],
        interrupt: Event,
        **kwargs,
    ) -> ExploiterResult:
        """
        Attempt to exploit a host using Log4Shell.

        :param host: Host to exploit
        :param servers: Passed through to the exploiter
        :param current_depth: Passed through to the exploiter
        :param options: Raw plugin options, parsed into ``Log4ShellOptions``. The caller's
                        mapping is not modified.
        :param interrupt: Passed through to the exploiter
        :param kwargs: Any additional keyword arguments are accepted and ignored
        :return: The exploiter's result, or an ``ExploiterResult`` carrying an error message
                 if the options cannot be parsed, no ports are left to try, or the exploiter
                 raises an exception
        """
        # HTTP ports options are a hack because they are needed in fingerprinters; they are not
        # Log4Shell options. Remove the key from a shallow copy so the caller's dict (which other
        # consumers may still need) is left untouched.
        options = dict(options)
        del_key(options, "http_ports")

        try:
            logger.debug(f"Parsing options: {pformat(options)}")
            log4shell_options = Log4ShellOptions(**options)
        except Exception as err:
            msg = f"Failed to parse Log4Shell options: {err}"
            logger.exception(msg)
            return ExploiterResult(error_message=msg)

        # Keep the selected ports local to this call rather than on ``self`` so that
        # concurrent or reentrant run() calls on the same instance cannot overwrite each
        # other's port selection.
        ports_to_try = get_ports_to_try(
            host, log4shell_options.target_ports, log4shell_options.try_all_discovered_http_ports
        )
        if not should_attempt_exploit(ports_to_try):
            msg = f"Could not find any open web ports to exploit on host {host.ip}"
            logger.debug(msg)
            return ExploiterResult(
                exploitation_success=False, propagation_success=False, error_message=msg
            )

        try:
            logger.debug(f"Running Log4Shell exploiter on host {host.ip}")
            return self._log4shell_exploiter.exploit_host(
                target_host=host,
                ports_to_try=ports_to_try,
                servers=servers,
                current_depth=current_depth,
                options=log4shell_options,
                interrupt=interrupt,
            )
        except Exception as err:
            # Plugin boundary: report any unexpected failure as a result rather than crashing
            # the caller.
            msg = f"An unexpected exception occurred while attempting to exploit host: {err}"
            logger.exception(msg)
            return ExploiterResult(error_message=msg)
